/*************************************************************************
 * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2023, Meta Platforms, Inc. and affiliates.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#ifndef TUNER_V5_H_
#define TUNER_V5_H_

// NVLink (NVL) domain information struct.
// Handed to the tuner's init() callback through its `nvlDomainInfo` argument.
// The field layout is part of the v5 plugin interface; do not reorder fields.
typedef struct {
  int nNvlDomains;                    // number of NVLink domains
  int minRanksPerNvlDomain;           // minimum ranks across all NVLink domains (smallest per-domain rank count)
  int maxRanksPerNvlDomain;           // maximum ranks across all NVLink domains (largest per-domain rank count)
} ncclNvlDomainInfo_v5_t;

// Dimension sizes of the tuning tables declared in ncclTunerConstants_v5_t.
// These values fix the in-memory layout of that struct, so they must stay in
// sync with whatever NCCL core was built with for the v5 tuner interface.

// Algorithm dimension.
// NOTE(review): the list in the trailing comment does not visibly account for
// all 7 entries; presumably the NVLS variants are also counted -- confirm
// against NCCL core.
#define NCCL_NUM_ALGORITHMS_V5 7 // Tree/Ring/CollNet*/PAT
// Protocol dimension.
#define NCCL_NUM_PROTOCOLS_V5 3 // Simple/LL/LL128
// First dimension of hwLatencies.
// NOTE(review): presumably one entry per hardware link type (e.g. NVLink/PCI/network) -- confirm.
#define NCCL_NUM_HW_LINKS_V5 3
// First dimension of the bandwidth tables.
// NOTE(review): presumably one entry per supported GPU compute-capability class -- confirm.
#define NCCL_NUM_COMPCAPS_V5 4
// Second dimension of the bandwidth tables.
// NOTE(review): presumably one entry per job-size (node count) scale -- confirm.
#define NCCL_NUM_TUNING_SCALES_V5 3

// Tuning constants exchanged with the tuner plugin.
// Passed to the tuner's init() callback as the Input/Output `constants`
// argument, i.e. the plugin may read the values and may overwrite them.
// NOTE(review): units are not stated in this header; presumably latencies are
// in microseconds and bandwidths in GB/s -- confirm against NCCL core.
typedef struct {
  // Latency tables.
  // baseLatencies is indexed [algorithm][protocol].
  double baseLatencies [NCCL_NUM_ALGORITHMS_V5][NCCL_NUM_PROTOCOLS_V5];
  // hwLatencies is indexed [hardware link][algorithm][protocol].
  double hwLatencies [NCCL_NUM_HW_LINKS_V5][NCCL_NUM_ALGORITHMS_V5][NCCL_NUM_PROTOCOLS_V5];

  // Bandwidth caps. Every table below is indexed [compute capability][tuning scale].
  // The "perCh" prefix suggests per-channel values -- TODO confirm.
  double llMaxBws [NCCL_NUM_COMPCAPS_V5][NCCL_NUM_TUNING_SCALES_V5];              // LL protocol
  double perChMaxRingLL128Bws [NCCL_NUM_COMPCAPS_V5][NCCL_NUM_TUNING_SCALES_V5];  // Ring algorithm, LL128 protocol
  double perChMaxTreeLL128Bws [NCCL_NUM_COMPCAPS_V5][NCCL_NUM_TUNING_SCALES_V5];  // Tree algorithm, LL128 protocol
  double perChMaxTreeBws [NCCL_NUM_COMPCAPS_V5][NCCL_NUM_TUNING_SCALES_V5];       // Tree algorithm
  double perChMaxNVLSTreeBws [NCCL_NUM_COMPCAPS_V5][NCCL_NUM_TUNING_SCALES_V5];   // NVLS Tree algorithm


} ncclTunerConstants_v5_t;

// API to be implemented by an external tuner plugin (interface version 5).
// The plugin fills one instance of this struct with its name and callbacks.
// The member order and the callback signatures are the plugin contract.
typedef struct {
  // Name of the tuner
  const char* name;

  // Initializes tuner states.
  // Inputs:
  //   - commId: communicator identifier
  //   - nRanks: number of ranks in the current communicator. Each communicator initializes its own tuner.
  //   - nNodes: number of nodes in the current communicator.
  //   - logFunction: a logging callback, useful to integrate plugin logging with NCCL core.
  //   - nvlDomainInfo: NVL domain information struct
  // Outputs:
  //   - ctx: tuner context object; this is the `context` value later handed to
  //          getCollInfo() and finalize().
  // Input/Output:
  //   - constants: tuner constants (may be read and modified by the plugin)
  // Returns an ncclResult_t status code.
  ncclResult_t (*init)(void** ctx, uint64_t commId, size_t nRanks, size_t nNodes, ncclDebugLogger_t logFunction,
                      ncclNvlDomainInfo_v5_t* nvlDomainInfo, ncclTunerConstants_v5_t* constants);

  // Gets tuning info (algorithm, protocol, number of channels) for a given collective.
  // Inputs:
  //   - context: tuner context object
  //   - collType: collective type, e.g., allreduce, allgather, ...
  //   - nBytes: collective size in bytes
  //   - numPipeOps: number of operations in the group
  //   - numAlgo: number of algorithms in collCostTable
  //   - numProto: number of protocols in collCostTable
  //   - regBuff: can register user buffer
  //
  // Outputs:
  //   - nChannels: number of channels (hence SMs) to be used.
  //
  // InOut:
  //   - collCostTable: collective cost table, generated by NCCL core, containing algo|proto|time entries for collType.
  //                    NCCL core sets ignored algo/proto cost table entries to -1.0 (NCCL_ALGO_PROTO_IGNORE).
  //                    NOTE(review): presumably the plugin expresses its algorithm/protocol
  //                    choice by editing entries of this table, since no separate
  //                    algorithm/protocol output parameter exists -- confirm against NCCL core.
  //                    NOTE(review): confirm the memory layout behind the float** (numAlgo x numProto)
  //                    with NCCL core before indexing it.
  //
  // If getCollInfo() does not return ncclSuccess, NCCL will fall back to the
  // default tuning for the given collective.
  // Also, the plugin is allowed to not set any output, or set only the
  // algorithm and protocol, but not only the algorithm or only the protocol.
  // Unset fields will be set automatically by NCCL.
  ncclResult_t (*getCollInfo)(void* context, ncclFunc_t collType, size_t nBytes,
                              int numPipeOps, float** collCostTable, int numAlgo, int numProto,
                              int regBuff, int* nChannels);

  // Terminates the plugin and cleans up any resources that the plugin allocated.
  // Inputs:
  //   - context: tuner context object
  // Returns an ncclResult_t status code.
  ncclResult_t (*finalize)(void* context);
} ncclTuner_v5_t;

#endif
